{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Save and Restore model\n",
    "\n",
    "\n",
    "\n",
    "還記得當我們在定義模型的時候，會使用 [Variables](https://www.tensorflow.org/api_docs/python/state_ops) 來儲存模型參數，而在訓練的時候則會一次又一次的更新它．這些參數在這些過程中都是儲存在記憶體裡．不過我們仍然需要在訓練過後把訓練完成的參數儲存在硬碟裡面，以便之後來使用或者對模型對進一步分析．\n",
    "\n",
    "在 Tensorflow 裡面最簡單的方法是使用 `tf.train.Saver` 這個物件 (object) 來儲存模型，在初始化這個物件的時候，裡面包含了計算圖 (graph) 中對於變量 (variables) 的許多種操作方法 (ops) ．而物件中又提供了許多方法 (method) 來執行這些操作方法 (ops)．它可以幫助我們簡單的把變量 (variables) 儲存成檔案，以提供之後使用．以下先附上兩個 class 的 api 連結，可以在裡面看到更為完整的說明．\n",
    "\n",
    "* [tf.Variable](https://www.tensorflow.org/api_docs/python/state_ops/variables#Variable) class\n",
    "* [tf.train.Saver](https://www.tensorflow.org/api_docs/python/state_ops/saving_and_restoring_variables#Saver) class\n",
    "\n",
    "## 儲存模型\n",
    "\n",
    "### Checkpoint 檔案\n",
    "\n",
    "變量 (Variables) 是被以 binary 的方式儲存成一個 checkpoint 檔 (.ckpt)，簡單的說它儲存了變量 (variable) 的名字和對應的張量 (tensor) 數值．\n",
    "\n",
    "當你建立了一個 Saver 物件以後，是可以選擇每個 variable 要用什麼名字儲存在 checkpoint 檔裡的．預設上則是會使用 [Variable.name](https://www.tensorflow.org/api_docs/python/state_ops/variables#Variable.name) 的值．\n",
    "\n",
    "為了要來了解到底儲存了什麼 variable 在 checkpoint 檔裡，你可以使用 [inpect_checkpoint](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py) 中的 `print_tensors_in_checkpoint_file` 函數來檢測．\n",
    "\n",
    "### 儲存變量 (Variables)\n",
    "\n",
    "以下是一個利用 `tf.train.Saver()` 來建立物件來儲存變量 (variables) 的範例．"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "os.makedirs(\"/tmp/model\")\n",
    "os.makedirs(\"/tmp/model-subset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[array([ 0.1,  0.1], dtype=float32), array([ 0.30000001,  0.30000001], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "# 建立一些變數以及對應的名字．\n",
    "\n",
    "v1 = tf.Variable([0.1, 0.1], name=\"v1\")\n",
    "v2 = tf.Variable([0.2, 0.2], name=\"v2\")\n",
    "\n",
    "# 建立所有 variables 的初始化 ops\n",
    "init_op = tf.global_variables_initializer()\n",
    "\n",
    "# 建立 saver 物件\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    \n",
    "    # 執行初始化\n",
    "    sess.run(init_op)\n",
    "    \n",
    "    #重新指定 v2 的值\n",
    "    ops = tf.assign(v2, [0.3, 0.3])\n",
    "    sess.run(ops)\n",
    "    \n",
    "    print sess.run(tf.global_variables())\n",
    "    # ... 中間略去許多模型定義以及訓練，例如可以是 MNIST 的定義以及訓練\n",
    "    \n",
    "    save_path = saver.save(sess, \"/tmp/model/model.ckpt\") # 儲存模型到 /tmp/model.ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "v1 (DT_FLOAT) [2]\r\n",
      "v2 (DT_FLOAT) [2]\r\n",
      "\r\n"
     ]
    }
   ],
   "source": [
    "# 使用 inspect_checkpoint.py tool 來印出 ckpt 檔\n",
    "!python /usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/inspect_checkpoint.py --file_name=/tmp/model/model.ckpt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 儲存特定變量 (Variables)\n",
    "\n",
    "如果我們沒有指定特定的參數給 `tf.train.Saver()`，它會自動把計算圖裡的所有變量都儲存進去．\n",
    "\n",
    "但有時候在儲存的時候我們會想要指定不一樣的名字，例如建立了一個變量叫做 \"weights\" 但是在儲存的時候我們會希望他的名字叫做 \"params\"\n",
    "\n",
    "也有些時候我們只想儲存部分的變量．例如訓練了一個 5 層的神經網路，但是我們只需要前 4 層的變量．\n",
    "\n",
    "我們可以很簡單地在 `Python dictionary` 裡指定變量以及對應的名字來給 `tf.train.Saver()` 當參數，就可以完成上面的目標了．範例如下\n",
    "\n",
    "註: \n",
    "- 你可以建立許多個 `saver` 物件，並且在這些物件之間，儲存或者載入部分的模型變量．同樣的變量會在不同的 `saver` 物件中存在，而且只有在執行 `restore()` 的時候才會被改變．\n",
    "- 如果你只在一個 session 裡載入了一個模型中一部分的變量，你必須要先用 `tf.global_variables_initializer()` 的方法先把其他的沒有載入的變量先初始化起來．"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "v1 = tf.Variable([0.1, 0.1], name=\"v1\")\n",
    "v2 = tf.Variable([0.4, 0.4], name=\"v2\")\n",
    "init_ops = tf.global_variables_initializer()\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init_ops)\n",
    "    saver = tf.train.Saver({\"my_v2\": v2})\n",
    "    save_path = saver.save(sess, \"/tmp/model-subset/model.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "my_v2 (DT_FLOAT) [2]\r\n",
      "\r\n"
     ]
    }
   ],
   "source": [
    "# 使用 inspect_checkpoint.py tool 來印出 ckpt 檔\n",
    "!python /usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/inspect_checkpoint.py --file_name=/tmp/model-subset/model.ckpt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 載入模型\n",
    "\n",
    "### 1. 預先定義好模型\n",
    "第一種方法來載入模型的時候，你必須先定義好原本的計算圖 (模型)，例如這裡我們先定義了 `v1` 和 `v2` 然後用 `sess.restore` 的方法載入之前被更改過的值．\n",
    "同樣的 Saver 物件也是可以拿來載入變量 (variables)．值得注意的是當要從檔案中載入變量 (variables) 的時候是不用先把它初始化的．\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model restored.\n",
      "all values [array([ 0.1,  0.1], dtype=float32), array([ 0.30000001,  0.30000001], dtype=float32)]\n",
      "v2 value : [ 0.30000001  0.30000001]\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "v1 = tf.Variable(tf.constant(0.1, shape = [2]), name=\"v1\")\n",
    "v2 = tf.Variable(tf.constant(0.2, shape = [2]), name=\"v2\")\n",
    "\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    saver.restore(sess, \"/tmp/model/model.ckpt\")\n",
    "    print(\"Model restored.\")\n",
    "    print(\"all values %s\" % sess.run(tf.global_variables()))\n",
    "    print(\"v2 value : %s\" % sess.run(v2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 載入 meta\n",
    "\n",
    "另一種方法則是在 `0.11.0RC1` 之後可以用載入 `meta` 的方式，就不用預先定義計算圖了．但是需要從重新拿取每一個參數的名字，以下就是範例．"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "v1:0 with value [ 0.1  0.1]\n",
      "v2:0 with value [ 0.40000001  0.40000001]\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "with tf.Session() as sess:\n",
    "    saver = tf.train.import_meta_graph('/tmp/model-subset/model.ckpt.meta')\n",
    "    saver.restore(sess, tf.train.latest_checkpoint('/tmp/model-subset/'))\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    all_vars = tf.trainable_variables()\n",
    "    for v in all_vars:\n",
    "        print(\"%s with value %s\" % (v.name, sess.run(v)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 學習資源連結\n",
    "- [[tensorflow 官方文件] variables](https://www.tensorflow.org/how_tos/variables/)\n",
    "- [[stackoverflow] How to restore model](http://stackoverflow.com/questions/33759623/tensorflow-how-to-restore-a-previously-saved-model-python)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
